// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package tests

import (
	"context"
	"fmt"
	"time"

	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster"
	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/clusterstats"
	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/grafana"
	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/option"
	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/registry"
	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestutil"
	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec"
	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test"
	"github.com/cockroachdb/cockroach/pkg/roachprod/install"
	"github.com/cockroachdb/cockroach/pkg/roachprod/prometheus"
	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
	"github.com/cockroachdb/errors"
	"github.com/prometheus/common/model"
	"github.com/stretchr/testify/require"
)

// This test exercises the behavior of range snapshots with ingest splits and
// excises enabled in the storage engine. It sets up a 3-node cluster
// pre-populated with about 500GB of data, then runs a foreground kv workload.
// An hour in, n3 is brought down; when it is restarted two hours later, it
// receives large amounts of snapshot data. With excises turned on, L0
// sub-level counts and p99 latencies are expected to remain stable.
func registerSnapshotOverloadExcise(r registry.Registry) {
	r.Add(registry.TestSpec{
		Name:             "admission-control/snapshot-overload-excise",
		Owner:            registry.OwnerAdmissionControl,
		Benchmark:        true,
		CompatibleClouds: registry.OnlyGCE,
		Suites:           registry.Suites(registry.Weekly),
		// The test uses a large volume size to ensure high provisioned bandwidth
		// from the cloud provider.
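		// The spec is 4 nodes total: 3 CRDB nodes plus a dedicated workload node.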
		Cluster: r.MakeClusterSpec(4, spec.CPU(4), spec.WorkloadNode(), spec.VolumeSize(2000)),
		Leases:  registry.MetamorphicLeases,
		Timeout: 12 * time.Hour,
		Run: func(ctx context.Context, t test.Test, c cluster.Cluster) {
			if c.Spec().NodeCount < 4 {
				t.Fatalf("expected at least 4 nodes, found %d", c.Spec().NodeCount)
			}

			envOptions := install.EnvOption{
				// COCKROACH_CONCURRENT_COMPACTIONS is set to 1 since we want to ensure
				// that snapshot ingests don't result in LSM inversion even with a very
				// low compaction rate. With Pebble's IngestAndExcise, all the ingested
				// sstables should land in L6.
				"COCKROACH_CONCURRENT_COMPACTIONS=1",
				// COCKROACH_RAFT_LOG_TRUNCATION_THRESHOLD is reduced so that there is
				// certainty that the restarted node will be caught up via snapshots,
				// and not via raft log replay.
				fmt.Sprintf("COCKROACH_RAFT_LOG_TRUNCATION_THRESHOLD=%d", 512<<10 /* 512KiB */),
				// COCKROACH_CONCURRENT_SNAPSHOT* is increased so that the rate of
				// snapshot application is high.
				"COCKROACH_CONCURRENT_SNAPSHOT_APPLY_LIMIT=100",
				"COCKROACH_CONCURRENT_SNAPSHOT_SEND_LIMIT=100",
			}

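			// Disable the default backup schedule, presumably so scheduled backups
			// don't add background I/O while we measure snapshot ingestion.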
			startOpts := option.NewStartOpts(option.NoBackupSchedule)
			roachtestutil.SetDefaultAdminUIPort(c, &startOpts.RoachprodOpts)
			roachtestutil.SetDefaultSQLPort(c, &startOpts.RoachprodOpts)
			settings := install.MakeClusterSettings(envOptions)
			c.Start(ctx, t.L(), startOpts, settings, c.CRDBNodes())

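			// Connect to the last CRDB node (n3) for the one-time cluster setting
			// changes below.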
			db := c.Conn(ctx, t.L(), len(c.CRDBNodes()))
			defer db.Close()

			t.Status(fmt.Sprintf("configuring cluster settings (<%s)", 30*time.Second))
			{
				// Defensive, since admission control is enabled by default.
				setAdmissionControl(ctx, t, c, true)
				// Ensure ingest splits and excises are enabled. (Enabled by default in v24.1+)
				if _, err := db.ExecContext(
					ctx, "SET CLUSTER SETTING kv.snapshot_receiver.excise.enabled = 'true'"); err != nil {
					t.Fatalf("failed to set kv.snapshot_receiver.excise.enabled: %v", err)
				}
				if _, err := db.ExecContext(
					ctx, "SET CLUSTER SETTING storage.ingest_split.enabled = 'true'"); err != nil {
					t.Fatalf("failed to set storage.ingest_split.enabled: %v", err)
				}

				// Set a high rebalance rate.
				if _, err := db.ExecContext(
					ctx, "SET CLUSTER SETTING kv.snapshot_rebalance.max_rate = '256MiB'"); err != nil {
					t.Fatalf("failed to set kv.snapshot_rebalance.max_rate: %v", err)
				}
			}

			// Set up the Prometheus instance and client.
			t.Status(fmt.Sprintf("setting up prometheus/grafana (<%s)", 2*time.Minute))
			var statCollector clusterstats.StatCollector
			promCfg := &prometheus.Config{}
			promCfg.WithPrometheusNode(c.WorkloadNode().InstallNodes()[0]).
				WithNodeExporter(c.CRDBNodes().InstallNodes()).
				WithCluster(c.CRDBNodes().InstallNodes()).
				WithGrafanaDashboardJSON(grafana.SnapshotAdmissionControlGrafanaJSON)
			err := c.StartGrafana(ctx, t.L(), promCfg)
			require.NoError(t, err)
			cleanupFunc := func() {
				if err := c.StopGrafana(ctx, t.L(), t.ArtifactsDir()); err != nil {
					t.L().ErrorfCtx(ctx, "Error(s) shutting down prom/grafana: %s", err)
				}
			}
			defer cleanupFunc()
			promClient, err := clusterstats.SetupCollectorPromClient(ctx, c, t.L(), promCfg)
			require.NoError(t, err)
			statCollector = clusterstats.NewStatsCollector(ctx, promClient)

			// Initialize the kv database.
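			// 40M rows at 12KiB per block is roughly 490GB of logical data, which
			// lines up with the ~500GB of pre-populated data described above.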
			t.Status(fmt.Sprintf("initializing kv dataset (<%s)", 2*time.Hour))
			c.Run(ctx, option.WithNodes(c.WorkloadNode()),
				"./cockroach workload init kv --drop --insert-count=40000000 "+
					"--max-block-bytes=12288 --min-block-bytes=12288 {pgurl:1-3}")

			t.Status(fmt.Sprintf("starting kv workload thread (<%s)", time.Minute))
			m := c.NewMonitor(ctx, c.CRDBNodes())
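			// Note that the workload only targets n1 and n2 ({pgurl:1-2}), likely so
			// that foreground traffic is unaffected when n3 is stopped and restarted
			// below.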
			m.Go(func(ctx context.Context) error {
				c.Run(ctx, option.WithNodes(c.WorkloadNode()),
					fmt.Sprintf("./cockroach workload run kv --tolerate-errors "+
						"--splits=1000 --histograms=%s/stats.json --read-percent=75 "+
						"--max-rate=600 --max-block-bytes=12288 --min-block-bytes=12288 "+
						"--concurrency=4000 --duration=%s {pgurl:1-2}",
						t.PerfArtifactsDir(), (6*time.Hour).String()))
				return nil
			})

			t.Status(fmt.Sprintf("waiting for data build up (<%s)", time.Hour))
			time.Sleep(time.Hour)

			t.Status(fmt.Sprintf("killing node 3... (<%s)", time.Minute))
			c.Stop(ctx, t.L(), option.DefaultStopOpts(), c.Node(3))

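			// With COCKROACH_RAFT_LOG_TRUNCATION_THRESHOLD lowered above, the
			// surviving nodes should truncate their raft logs past n3's applied
			// state during this window, ensuring n3 is later caught up via
			// snapshots rather than raft log replay.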
			t.Status(fmt.Sprintf("waiting for increased snapshot data and raft log truncation (<%s)", 2*time.Hour))
			time.Sleep(2 * time.Hour)

			t.Status(fmt.Sprintf("starting node 3... (<%s)", time.Minute))
			c.Start(ctx, t.L(), startOpts, install.MakeClusterSettings(envOptions), c.Node(3))

			t.Status(fmt.Sprintf("waiting for snapshot transfers to finish %s", 2*time.Hour))
			m.Go(func(ctx context.Context) error {
				t.Status(fmt.Sprintf("starting monitoring thread (<%s)", time.Minute))
				getMetricVal := func(query string, label string) (float64, error) {
					point, err := statCollector.CollectPoint(ctx, t.L(), timeutil.Now(), query)
					if err != nil {
						t.L().Errorf("could not query prom: %s", err.Error())
						return 0, err
					}
					val := point[label]
					for storeID, v := range val {
						t.L().Printf("%s(store=%s): %f", query, storeID, v.Value)
						// We only assert on the 3rd store.
						if storeID == "3" {
							return v.Value, nil
						}
					}
					// Store 3 was missing from the query result; surface an error
					// (counted by the caller) rather than panicking.
					return 0, errors.Newf("store 3 not found in results for query %s", query)
				}
				getHistMetricVal := func(query string) (float64, error) {
					at := timeutil.Now()
					fromVal, warnings, err := promClient.Query(ctx, query, at)
					if err != nil {
						return 0, err
					}
					if len(warnings) > 0 {
						return 0, errors.Newf("found warnings querying prometheus: %s", warnings)
					}

					fromVec := fromVal.(model.Vector)
					if len(fromVec) == 0 {
						return 0, errors.Newf("Empty vector result for query %s @ %s (%v)", query, at.Format(time.RFC3339), fromVal)
					}
					return float64(fromVec[0].Value), nil
				}

				// Assert on L0 sub-level count and p99 latencies.
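				// The SQL latency histogram is recorded in nanoseconds; dividing by
				// 1<<20 (~1e6) rescales it to roughly milliseconds.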
				latencyMetric := divQuery("histogram_quantile(0.99, sum by(le) (rate(sql_service_latency_bucket[2m])))", 1<<20 /* ~1ms */)
				const latencyThreshold = 100 // ~100ms, since the metric is scaled to ~1ms above.
				const sublevelMetric = "storage_l0_sublevels"
				const sublevelThreshold = 20
				var l0SublevelCount []float64
				const sampleCountForL0Sublevel = 12
				const collectionIntervalSeconds = 10.0
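				// 12 samples at the 10s collection interval covers the trailing 2
				// minutes of data used for the mean below.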
				// Loop for ~120 minutes (720 iterations at the 10s collection interval).
				const numIterations = int(120 / (collectionIntervalSeconds / 60))
				numErrors := 0
				numSuccesses := 0
				for i := 0; i < numIterations; i++ {
					time.Sleep(collectionIntervalSeconds * time.Second)
					val, err := getHistMetricVal(latencyMetric)
					if err != nil {
						numErrors++
						continue
					}
					if val > latencyThreshold {
						t.Fatalf("sql p99 latency %f exceeded threshold", val)
					}
					val, err = getMetricVal(sublevelMetric, "store")
					if err != nil {
						numErrors++
						continue
					}
					l0SublevelCount = append(l0SublevelCount, val)
					// We want to use the mean of the last 2m of data to avoid short-lived
					// spikes causing failures.
					if len(l0SublevelCount) >= sampleCountForL0Sublevel {
						latestSampleMeanL0Sublevels := getMeanOverLastN(sampleCountForL0Sublevel, l0SublevelCount)
						if latestSampleMeanL0Sublevels > sublevelThreshold {
							t.Fatalf("sub-level mean %f over last %d iterations exceeded threshold", latestSampleMeanL0Sublevels, sampleCountForL0Sublevel)
						}
					}
					numSuccesses++
				}
				t.Status(fmt.Sprintf("done monitoring, errors: %d successes: %d", numErrors, numSuccesses))
				if numErrors > numSuccesses {
					t.Fatalf("too many errors retrieving metrics")
				}
				return nil
			})

			m.Wait()
		},
	})
}
